# from proxy import ProxyTest
# from proxytest_mul import ProxyTest_mul
from proxytest_mulx import ProxyTest_mulx
import numpy as np
import pandas as pd
from functools import partial
from binn import quantile_bin,uniform_bin
from tools import _save
from itertools import product
import multiprocess as mp
from tqdm import tqdm
from data import Experiment1,Experiment2,Experiment3,Experiment4,Experiment5,CoupleData,CoupleData1

def discretize(levelx,levelu,levely,data):
    '''
    Discretize the A, U, W and Y fields of `data` with `quantile_bin`.

    levelx is passed to quantile_bin for data.A, levelu for both data.U and
    data.W, and levely for data.Y. The four binned results are stacked along
    axis 1 in the order A, U, W, Y and returned as a single numpy array.

    Original author's guidance: quantile binning is recommended for X only,
    when a moderate sample size is available.
    # assumes quantile_bin returns one equal-length 1-D array per field, so the
    # result has one row per sample and four columns -- TODO confirm
    '''
    binned_columns = [
        quantile_bin(data.A, levelx),
        quantile_bin(data.U, levelu),
        quantile_bin(data.W, levelu),  # W shares U's bin level
        quantile_bin(data.Y, levely),
    ]
    return np.stack(binned_columns, axis=1)


def _partialtest(tuple_of_hyper):
    '''
    Pool-friendly wrapper: unpack one (size, levela) pair and run `partialtest`.

    `partialtest` is looked up as a module-level global. In this file it is
    only bound inside the `if __name__ == '__main__':` section, so calling
    this function after a plain import raises NameError.
    # NOTE(review): worker processes must also see that global; whether they do
    # may depend on the multiprocessing start method -- confirm on the target OS.

    Returns a two-element list holding the two values returned by
    `partialtest`, in the same order.
    '''
    sample_size, bins_a = tuple_of_hyper
    first, second = partialtest(size=sample_size, levela=bins_a)
    return [first, second]


def test(size, levela, levelu, levely, num_trials):
    '''
    Run the proxy test repeatedly on freshly drawn data and summarise it.

    Two passes are made. In each pass, `num_trials` successful trials are
    collected. One trial draws CoupleData(size), quantile-bins A1 (levela),
    U (levelu), A2 (levelu) and Y1 (levely), puts them in a DataFrame with
    columns A, U, W, Y and asks ProxyTest_mulx for a p-value. A trial that
    raises AssertionError is discarded and redrawn without being counted.

    Returns (p_list, powers): for each pass, the mean p-value and the
    fraction of trials whose p-value is below 0.05.
    '''
    column_names = ['A', 'U', 'W', 'Y']
    rejection_rates = []
    mean_pvalues = []

    # NOTE(review): `causal` is never read, so both passes generate data in
    # exactly the same way -- confirm whether CoupleData should receive it.
    for causal in (True, False):
        rejections = 0
        pvalue_total = 0
        completed = 0
        # NOTE(review): no retry limit -- if every draw raises AssertionError
        # this loop never terminates.
        while completed < num_trials:
            sample = CoupleData(size)
            binned = np.stack(
                [
                    quantile_bin(sample.A1, levela),
                    quantile_bin(sample.U, levelu),
                    quantile_bin(sample.A2, levelu),  # second treatment acts as W
                    quantile_bin(sample.Y1, levely),
                ],
                axis=1,
            )
            frame = pd.DataFrame(binned, columns=column_names)
            try:
                pvalue = ProxyTest_mulx(frame, levela, levely, levelu)._proxytest()
            except AssertionError:
                # Failed draw: try again without advancing the trial counter.
                continue
            pvalue_total += pvalue
            if pvalue < 0.05:
                rejections += 1
            completed += 1
        mean_pvalues.append(pvalue_total / num_trials)
        rejection_rates.append(rejections / num_trials)
    return mean_pvalues, rejection_rates





if __name__ == '__main__':
    # Experiment grid: bin levels for A / Y / U and the sample sizes to sweep.
    levela = [15]
    levely = 5
    rangeu = 8
    rangesize = [400,600,800,1000,1200,1400,1600]
    num_trials = 50
    num_params = 10  # not referenced anywhere below in this script

    # Store the configuration next to the results.
    # Note that the 'levelx' key holds the `levela` list.
    meta_params = {
        'levelx': list(levela),
        'levely': levely,
        'rangeu': rangeu,
        'rangesize': list(rangesize),
        'num_trials': num_trials,
    }
    _save(meta_params, 'params.json')

    # `_partialtest` resolves `partialtest` as a module global, so the name
    # must stay exactly as is.
    partialtest = partial(test, levelu=rangeu, levely=levely, num_trials=num_trials)
    tuple_of_hypers = product(rangesize, levela)
    num_jobs = len(rangesize) * len(levela)

    with mp.Pool(9) as p:
        # imap yields results in submission order, one per (size, levela) pair;
        # tqdm only wraps the iterator to display progress.
        results = p.imap(_partialtest, tuple_of_hypers)
        record = list(tqdm(iterable=results, total=num_jobs))
        _save(record, 'record.json')
        
